#!/usr/bin/env python3
from __future__ import annotations
from rpi import logger
from rpi.helpers.data import flatten
from typing import List
import numpy as np


class NewStateDetector:
    """Placeholder novelty ("new-state") detector.

    ``experience`` stores the states extracted from the given episodes, and
    ``evaluate`` currently returns a constant score of 0 for every state.
    """

    def __init__(self):
        # States from the most recent ``experience`` call; None until then.
        self._buffer = None

    def experience(self, all_episodes: List):
        """Expose the new-state detector to states.

        Args:
            - all_episodes (List[dict] | List[List[dict]]): a single episode
              (list of transition dicts) or a list of episodes. Each
              transition dict must contain a ``"state"`` key.

        The extracted states replace any previously stored buffer.
        """
        # Example implementation: just remember the latest batch of states.
        self._buffer = extract_states(all_episodes)

    def evaluate(self, state: np.ndarray):
        """Evaluate the new-state detector on a single state.

        Returns
            - a score (float) that represents the novelty of the state.
              This placeholder always returns 0.
        """
        return 0

    def batch_evaluate(self, batch_states: np.ndarray) -> np.ndarray:
        """Evaluate a batch of states, one ``evaluate`` call per element.

        Args:
            - batch_states: any iterable of states (e.g. an array whose first
              axis indexes states).

        Returns
            - scores (np.ndarray): one score per state, in input order.
        """
        scores = [self.evaluate(state) for state in batch_states]
        # Convert so the return type matches the documented np.ndarray contract.
        return np.asarray(scores)


def extract_states(episodes: List[dict] | List[List[dict]]) -> np.ndarray:
    """Collect the ``"state"`` entry of every transition into one array.

    Args:
        - episodes: either a single episode (List[dict] of transitions) or a
          list of episodes (List[List[dict]]). Every transition dict must
          contain a ``"state"`` key.

    Returns:
        - batch of states, built with ``np.asarray`` over the states in
          transition order. An empty input yields an empty array of shape
          ``(0,)``.

    Raises:
        - TypeError: if the first element is neither a dict nor a list.
        - KeyError: if a transition has no ``"state"`` key.
    """
    if len(episodes) == 0:
        # Nothing to inspect; avoid an IndexError on episodes[0] below.
        return np.asarray([])

    first = episodes[0]
    if isinstance(first, dict):
        # single episode case
        transitions = episodes
    elif isinstance(first, list):
        # multiple episodes case: merge all episodes into one transition list
        transitions = flatten(episodes)
    else:
        raise TypeError("episodes must be List[dict] or List[List[dict]]")

    return np.asarray([tr["state"] for tr in transitions])


def simple_make_env(env_name, test: bool = False, default_seed: int = 0):
    """Build an environment factory and report basic environment metadata.

    Args:
        env_name: forwarded to ``rpi.scripts.train._make_env``.
        test: forwarded positionally to ``_make_env``.
        default_seed: forwarded to ``_make_env`` as ``default_seed``.

    Returns:
        A 4-tuple ``(make_env, state_dim, act_dim, env_id)``:
        ``make_env`` is a factory that accepts and ignores any arguments and
        calls ``_make_env`` with the values above; ``state_dim`` / ``act_dim``
        are the ``.size`` of the observation / action space ``low`` bounds of
        a probe environment; ``env_id`` is that environment's unwrapped
        ``spec.id`` lower-cased.
    """
    # Function-scope import; presumably to avoid an import cycle or a heavy
    # import at module load time — TODO confirm.
    from rpi.scripts.train import _make_env

    def make_env(*_args, **_kwargs):
        # Any arguments supplied by the caller are ignored.
        return _make_env(env_name, test, default_seed=default_seed)

    # NOTE(review): this probe env is never closed after its metadata is read.
    probe_env = make_env()

    obs_size = probe_env.observation_space.low.size
    action_size = probe_env.action_space.low.size

    for tag, value in (
        ("env", env_name),
        ("obs_dim", obs_size),
        ("act_dim", action_size),
    ):
        logger.info(tag, value)

    env_id = probe_env.unwrapped.spec.id.lower()
    return make_env, obs_size, action_size, env_id
